"""
Phase 3: OECD Null Resolution
===============================
Phase 2 showed that the Z₁→savings effect collapses for OECD countries (127*** → 8).
There are three competing hypotheses for why:

H1 (Income): OECD proxies for high income; it's income that attenuates.
    Test: Z + Z×income_high + Z×OECD → does OECD interaction vanish?

H2 (Openness): OECD countries are financially open; openness channels
    demographic effects differently.
    Test: Z + Z×kaopen + Z×OECD → does OECD interaction vanish?

H3 (Institutional bundle): OECD membership itself (deep markets, pension
    systems, fiscal capacity) dampens the savings response.
    Test: OECD interaction survives controlling for income + openness.

Also run a "horse race" that includes all three moderators (income, openness, OECD) simultaneously.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# NOTE(review): this silences *all* warnings process-wide, including any the
# estimator may emit — confirm that is intended.
warnings.filterwarnings('ignore')

# PROJECT_DIR is the parent of the directory holding this script; ROOT_DIR is
# one level above that.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# PanelGLS is imported from <ROOT_DIR>/multilateral/src (presumably a sibling
# project); the path must be prepended *before* the import below runs.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory (holds unified_panel.csv) and output directory for tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"

# Base regressors in every model; interaction columns are named
# '<Z>_x_<moderator>' (e.g. 'Z_1_x_oecd').
Z_VARS = ['Z_1', 'Z_2', 'Z_3']


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''
    (NaN compares false everywhere and therefore yields '').
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, p):
    """Render coefficient *val* to three decimals followed by its significance stars for *p*."""
    number = "{:.3f}".format(val)
    return number + stars(p)


def run_gls(df, y_var, x_vars, min_obs=50):
    """Fit a PanelGLS regression of *y_var* on *x_vars* using complete cases.

    Rows with a missing value in any model column (or in ``iso3``/``year``)
    are dropped before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing *y_var*, *x_vars* and the ``iso3``/``year`` keys.
    y_var : str
        Dependent-variable column.
    x_vars : sequence of str
        Regressor columns; result keys are produced in this order.
    min_obs : int, default 50
        Minimum number of complete-case rows required to attempt a fit.

    Returns
    -------
    dict or None
        ``n_obs``, ``n_countries``, ``r_squared`` and, for every regressor
        ``v``, the keys ``v_coef``, ``v_se`` and ``v_p``.  ``None`` when a
        required column is absent, fewer than *min_obs* complete rows remain,
        or the fit raises (the reason is reported on stderr).
    """
    x_vars = list(x_vars)
    cols = [y_var] + x_vars + ['iso3', 'year']
    missing = [c for c in cols if c not in df.columns]
    if missing:
        # A missing column would otherwise raise KeyError and abort the whole
        # run; callers already treat None as "model not estimated".
        print(f"  [run_gls] {y_var}: missing column(s) {missing}; skipped",
              file=sys.stderr)
        return None
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        return None
    gls = PanelGLS()
    try:
        gls.fit(sub[y_var].values, sub[x_vars].values,
                sub['iso3'].values, sub['year'].values)
    except Exception as exc:
        # Best effort: a failed fit skips this model, but report why instead of
        # swallowing it silently. stderr rather than warnings.warn because all
        # warnings are filtered out at module level.
        print(f"  [run_gls] {y_var} ~ {x_vars}: fit failed "
              f"({type(exc).__name__}: {exc})", file=sys.stderr)
        return None
    result = {'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
              'r_squared': gls.r_squared}
    # NOTE(review): assumes gls.beta / se / pvalues are ordered like x_vars
    # with no intercept prepended — confirm against PanelGLS.
    for i, var in enumerate(x_vars):
        result[f'{var}_coef'] = gls.beta[i]
        result[f'{var}_se'] = gls.se[i]
        result[f'{var}_p'] = gls.pvalues[i]
    return result


# Significance threshold for the verdict logic; equals the weakest level that
# `stars` marks ('*' for p < 0.1), so verdicts and table stars agree.
_ALPHA = 0.10

# Moderator suffix -> label used in console output.
_MODERATOR_LABELS = {'high': 'Z₁×High', 'kaopen': 'Z₁×KA', 'oecd': 'Z₁×OECD'}

# (stored model name, console label, moderators interacted with Z).
# Moderator order fixes the regressor order passed to run_gls. M1 — the only
# spec without moderators — must come first: it gates the rest of the DV.
_MODEL_SPECS = [
    ('M1_baseline', 'M1 (baseline)', []),
    ('M2_oecd', 'M2 (+OECD)', ['oecd']),
    ('M3_income_oecd', 'M3 (+High+OECD)', ['high', 'oecd']),
    ('M4_kaopen_oecd', 'M4 (+KAopen+OECD)', ['kaopen', 'oecd']),
    ('M5_horserace', 'M5 (horse race)', ['high', 'kaopen', 'oecd']),
]


def main():
    """Run the Phase 3 model sequence, write the markdown table, print a verdict.

    For each test DV, fits M1–M5 with `run_gls`. It then writes
    ``OUT_TABLES/phase3_oecd_resolution.md`` and prints the H1/H2/H3 verdict
    for ``gross_savings_gdp``.
    """
    print("Phase 3: OECD Null Resolution")
    print("=" * 70)

    df = pd.read_csv(DATA / "unified_panel.csv")
    print(f"Panel: {len(df)} obs\n")

    # Test DVs where OECD attenuation was visible
    test_dvs = ['ca_gdp', 'gross_savings_gdp', 'gross_investment_gdp']

    results = []
    for dv in test_dvs:
        results.extend(_estimate_dv(df, dv))
    results_df = pd.DataFrame(results)

    print(f"\n{'=' * 70}")
    if results_df.empty:
        # An empty frame has no 'dv' column, so the filtering below would raise.
        print("No model could be estimated; no tables written.")
        return

    print("Writing output tables...")
    _write_table(results_df, test_dvs, OUT_TABLES / "phase3_oecd_resolution.md")
    print("  Wrote: phase3_oecd_resolution.md")

    _print_verdict(results_df, 'gross_savings_gdp')

    print("\nPhase 3 complete.")


def _interactions(moderator):
    """Return the Z×moderator column names (e.g. 'Z_1_x_oecd') in Z_VARS order."""
    return [f'{z}_x_{moderator}' for z in Z_VARS]


def _estimate_dv(df, dv):
    """Fit every model in _MODEL_SPECS for *dv*, printing one summary line each.

    Returns one result row (dict) per successfully estimated model. The list
    is empty when the baseline M1 cannot be estimated.
    """
    print(f"\n{'─' * 50}")
    print(f"DV: {dv}")
    print(f"{'─' * 50}")

    rows = []
    for name, label, moderators in _MODEL_SPECS:
        x_vars = Z_VARS + [col for mod in moderators for col in _interactions(mod)]
        fit = run_gls(df, dv, x_vars)
        if fit is None:
            print(f"  {label}: not estimated")
            if not moderators:
                # No baseline -> nothing to compare the moderated models against.
                return rows
            continue

        row = {'dv': dv, 'model': name, 'n_obs': fit['n_obs'],
               'r2': fit['r_squared'],
               'Z1_coef': fit['Z_1_coef'], 'Z1_p': fit['Z_1_p']}
        parts = [f"Z₁={fmt(row['Z1_coef'], row['Z1_p'])}"]
        for mod in moderators:
            coef, p = fit[f'Z_1_x_{mod}_coef'], fit[f'Z_1_x_{mod}_p']
            row[f'Z1x{mod}_coef'] = coef
            row[f'Z1x{mod}_p'] = p
            parts.append(f"{_MODERATOR_LABELS[mod]}={fmt(coef, p)}")
        print(f"  {label}: " + ", ".join(parts))
        rows.append(row)
    return rows


def _cell(row, stem):
    """Format row['<stem>_coef'] with stars from row['<stem>_p'], or '--' if absent/NaN."""
    coef = row.get(f'{stem}_coef')
    if pd.isna(coef):
        return '--'
    return fmt(coef, row.get(f'{stem}_p', 1.0))


def _write_table(results_df, dvs, path):
    """Write the per-DV markdown comparison table plus interpretation notes to *path*."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the table contains characters such as the subscript in
    # "Z₁" that locale defaults like cp1252 cannot encode.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Phase 3: OECD Null Resolution — Competing Hypotheses\n\n")
        for dv in dvs:
            sub = results_df[results_df['dv'] == dv]
            if sub.empty:
                continue
            f.write(f"\n## {dv}\n\n")
            f.write("| Model | Z₁ | Z₁×OECD | Z₁×High | Z₁×KAopen | N | R² |\n")
            f.write("|---|---|---|---|---|---|---|\n")
            for _, row in sub.iterrows():
                f.write(f"| {row['model']} | {fmt(row['Z1_coef'], row['Z1_p'])} | "
                        f"{_cell(row, 'Z1xoecd')} | {_cell(row, 'Z1xhigh')} | "
                        f"{_cell(row, 'Z1xkaopen')} | {int(row['n_obs'])} | "
                        f"{row['r2']:.3f} |\n")

        f.write("\n## Interpretation\n\n")
        f.write("- **H1 (Income)**: If Z₁×OECD loses significance when Z₁×High is added, "
                "OECD proxies for income.\n")
        f.write("- **H2 (Openness)**: If Z₁×OECD loses significance when Z₁×KAopen is added, "
                "OECD proxies for financial openness.\n")
        f.write("- **H3 (Institutional)**: If Z₁×OECD survives all controls, "
                "OECD membership itself attenuates demographic effects.\n\n")
        f.write("*PanelGLS with AR(1) correction. \\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*\n")


def _pval(row, stem):
    """Return row['<stem>_p'] as a float, or None when the row/value is absent or NaN."""
    p = None if row is None else row.get(f'{stem}_p')
    return None if pd.isna(p) else float(p)


def _absorption_verdict(row, control, hypothesis, what):
    """Print whether Z₁×OECD in *row* is absorbed by the Z₁×<control> interaction."""
    oecd_p = _pval(row, 'Z1xoecd')
    ctrl_p = _pval(row, f'Z1x{control}')
    if oecd_p is None:
        print(f"  → {hypothesis} inconclusive: Z₁×OECD not estimated")
    elif oecd_p < _ALPHA:
        print(f"  → {hypothesis} NOT supported: OECD survives {what} control")
    elif ctrl_p is not None and ctrl_p < _ALPHA:
        print(f"  → {hypothesis} SUPPORTED: OECD effect absorbed by {what}")
    else:
        print(f"  → {hypothesis} inconclusive: neither Z₁×OECD nor the "
              f"{what} interaction is significant")


def _print_verdict(results_df, dv):
    """Print which of H1/H2/H3 the Z₁ interaction p-values favour for *dv*."""
    sub = results_df[results_df['dv'] == dv]
    if sub.empty:
        return
    rows = {row['model']: row for _, row in sub.iterrows()}

    print(f"\nVerdict for {dv}:")
    m2 = rows.get('M2_oecd')
    if m2 is not None:
        print(f"  M2: Z₁×OECD = {_cell(m2, 'Z1xoecd')}")
        m2_p = _pval(m2, 'Z1xoecd')
        if m2_p is None or m2_p >= _ALPHA:
            # H1/H2 ask whether an OECD effect *vanishes*; with no significant
            # effect in M2 there is nothing to absorb.
            print("  (Z₁×OECD not significant in M2 — absorption verdicts are uninformative)")
    if 'M3_income_oecd' in rows:
        _absorption_verdict(rows['M3_income_oecd'], 'high', 'H1', 'income')
    if 'M4_kaopen_oecd' in rows:
        _absorption_verdict(rows['M4_kaopen_oecd'], 'kaopen', 'H2', 'openness')
    m5 = rows.get('M5_horserace')
    if m5 is not None:
        p5 = _pval(m5, 'Z1xoecd')
        if p5 is None:
            print("  → H3 inconclusive: Z₁×OECD not estimated in horse race")
        elif p5 < _ALPHA:
            print("  → H3 SUPPORTED: OECD survives horse race")
        else:
            print("  → H3 NOT supported: OECD absorbed in horse race")


if __name__ == '__main__':
    main()
